{
  "cells": [
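    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Visualizing Relay with Relay Visualizer\n",
        "\n",
        "**Original author**: [Chi-Wei Wang](https://github.com/chiwwang)\n",
        "\n",
        "A Relay IR module can contain many operations. Although an individual operation is usually easy to understand, putting them together can produce a complicated, hard-to-read graph. Things can get even worse once optimization passes come into play.\n",
        "\n",
        "This utility visualizes an IR module as nodes and edges. It defines a set of interfaces: parser, plotter (renderer), graph, node, and edge.\n",
        "A default parser is provided. Users can implement their own renderers to render the graph.\n",
        "\n",
        "This tutorial uses a renderer that draws the graph in text form. It is a lightweight, AST-like visualizer inspired by [clang ast-dump](https://clang.llvm.org/docs/IntroductionToTheClangAST.html). The sections below show how to implement customized parsers and renderers through the interface classes.\n",
        "\n",
        "For more details, see {py:mod}`tvm.contrib.relay_viz`."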
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm import relay\n",
        "from tvm.contrib.relay_viz import RelayVisualizer, DotPlotter, DotVizParser\n",
        "from tvm.contrib.relay_viz.interface import (\n",
        "    VizEdge,\n",
        "    VizNode,\n",
        "    VizParser,\n",
        ")\n",
        "from tvm.contrib.relay_viz.terminal import (\n",
        "    TermGraph,\n",
        "    TermPlotter,\n",
        "    TermVizParser,\n",
        ")"
      ]
    },
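    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define a Relay IR Module with Multiple `GlobalVar`s\n",
        "\n",
        "Build an example IR module that contains multiple `GlobalVar`s. Define an `add` function and call it in the `main` function.\n",
        "\n",
        "```{rubric} Create the add operator and its function\n",
        "```"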
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "data = relay.var(\"data\")\n",
        "bias = relay.var(\"bias\")\n",
        "add_op = relay.add(data, bias)\n",
        "add_func = relay.Function([data, bias], add_op)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Inspect the operator and the function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Operator:\n",
            "free_var %data;\n",
            "free_var %bias;\n",
            "add(%data, %bias)\n",
            "====================\n",
            "Function:\n",
            "fn (%data, %bias) {\n",
            "  add(%data, %bias)\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "print(f\"Operator:\\n{add_op}\")\n",
        "print(\"=\"*20)\n",
        "print(f\"Function:\\n{add_func}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "add_gvar = relay.GlobalVar(\"AddFunc\")\n",
        "input0 = relay.var(\"input0\")\n",
        "input1 = relay.var(\"input1\")\n",
        "input2 = relay.var(\"input2\")\n",
        "add_01 = relay.Call(add_gvar, [input0, input1])\n",
        "add_012 = relay.Call(add_gvar, [input2, add_01])\n",
        "main_func = relay.Function([input0, input1, input2], add_012)\n",
        "main_gvar = relay.GlobalVar(\"main\")\n",
        "\n",
        "mod = tvm.IRModule({main_gvar: main_func,\n",
        "                    add_gvar: add_func})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Render the Graph in the Terminal with Relay Visualizer\n",
        "\n",
        "The terminal renderer shows a Relay IR module as text, similar to clang AST-dump.\n",
        "\n",
        "The output should show the ``main`` and ``AddFunc`` functions. ``AddFunc`` is called twice in ``main``."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main([Var(input0), Var(input1), Var(input2)])\n",
            "`--Call \n",
            "   |--GlobalVar AddFunc\n",
            "   |--Var(Input) name_hint: input2\n",
            "   `--Call \n",
            "      |--GlobalVar AddFunc\n",
            "      |--Var(Input) name_hint: input0\n",
            "      `--Var(Input) name_hint: input1\n",
            "@AddFunc([Var(data), Var(bias)])\n",
            "`--Call \n",
            "   |--add \n",
            "   |--Var(Input) name_hint: data\n",
            "   `--Var(Input) name_hint: bias\n"
          ]
        }
      ],
      "source": [
        "viz = RelayVisualizer(mod)\n",
        "viz.render()"
      ]
    },
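    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize the Parser for Relay Types of Interest\n",
        "\n",
        "Sometimes we want to emphasize particular information, or parse things differently for a specific use. A customized parser can be supplied as long as it obeys the interface.\n",
        "\n",
        "This section demonstrates how to customize parsing for variables created with ``relay.var``.\n",
        "\n",
        "To do so, implement the abstract interface {py:class}`tvm.contrib.relay_viz.interface.VizParser`."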
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class YourAwesomeParser(VizParser):\n",
        "    def __init__(self):\n",
        "        self._delegate = TermVizParser()\n",
        "\n",
        "    def get_node_edges(\n",
        "        self,\n",
        "        node: relay.Expr,\n",
        "        relay_param: dict[str, tvm.runtime.NDArray],\n",
        "        node_to_id: dict[relay.Expr, str],\n",
        "    ) -> tuple[VizNode | None, list[VizEdge]]:\n",
        "\n",
        "        if isinstance(node, relay.Var):\n",
        "            node = VizNode(node_to_id[node], \"AwesomeVar\", f\"name_hint {node.name_hint}\")\n",
        "            # no edge is introduced. So return an empty list.\n",
        "            return node, []\n",
        "\n",
        "        # delegate other types to the other parser.\n",
        "        return self._delegate.get_node_edges(node, relay_param, node_to_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pass the parser and the renderer of interest to the visualizer.\n",
        "\n",
        "Here we simply use the terminal renderer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main([Var(input0), Var(input1), Var(input2)])\n",
            "`--Call \n",
            "   |--GlobalVar AddFunc\n",
            "   |--AwesomeVar name_hint input2\n",
            "   `--Call \n",
            "      |--GlobalVar AddFunc\n",
            "      |--AwesomeVar name_hint input0\n",
            "      `--AwesomeVar name_hint input1\n",
            "@AddFunc([Var(data), Var(bias)])\n",
            "`--Call \n",
            "   |--add \n",
            "   |--AwesomeVar name_hint data\n",
            "   `--AwesomeVar name_hint bias\n"
          ]
        }
      ],
      "source": [
        "viz = RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser())\n",
        "viz.render()"
      ]
    },
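    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize the Graph and Plotter\n",
        "\n",
        "Besides parsers, the graph and renderer can also be customized by implementing the abstract classes {py:class}`tvm.contrib.relay_viz.interface.VizGraph` and {py:class}`tvm.contrib.relay_viz.interface.Plotter`.\n",
        "\n",
        "For ease of demonstration, the code below overrides ``TermGraph``, which is defined in ``terminal.py``. It adds a hook that attaches a duplicate node to each ``AwesomeVar`` and makes ``TermPlotter`` use the new class."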
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main([Var(input0), Var(input1), Var(input2)])\n",
            "`--Call \n",
            "   |--GlobalVar AddFunc\n",
            "   |--AwesomeVar name_hint input2\n",
            "   |  `--double AwesomeVar \n",
            "   `--Call \n",
            "      |--GlobalVar AddFunc\n",
            "      |--AwesomeVar name_hint input0\n",
            "      |  `--double AwesomeVar \n",
            "      `--AwesomeVar name_hint input1\n",
            "         `--double AwesomeVar \n",
            "@AddFunc([Var(data), Var(bias)])\n",
            "`--Call \n",
            "   |--add \n",
            "   |--AwesomeVar name_hint data\n",
            "   |  `--double AwesomeVar \n",
            "   `--AwesomeVar name_hint bias\n",
            "      `--double AwesomeVar \n"
          ]
        }
      ],
      "source": [
        "class AwesomeGraph(TermGraph):\n",
        "    def node(self, viz_node):\n",
        "        # add the node first\n",
        "        super().node(viz_node)\n",
        "        # if it's AwesomeVar, duplicate it.\n",
        "        if viz_node.type_name == \"AwesomeVar\":\n",
        "            duplicated_id = f\"duplicated_{viz_node.identity}\"\n",
        "            duplicated_type = \"double AwesomeVar\"\n",
        "            super().node(VizNode(duplicated_id, duplicated_type, \"\"))\n",
        "            # connect the duplicated var to the original one\n",
        "            super().edge(VizEdge(duplicated_id, viz_node.identity))\n",
        "\n",
        "\n",
        "# override TermPlotter to use `AwesomeGraph` instead\n",
        "class AwesomePlotter(TermPlotter):\n",
        "    def create_graph(self, name):\n",
        "        self._name_to_graph[name] = AwesomeGraph(name)\n",
        "        return self._name_to_graph[name]\n",
        "\n",
        "\n",
        "viz = RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser())\n",
        "viz.render()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The module can also be rendered with the Graphviz-based ``DotPlotter``:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "from tvm.contrib import relay_viz\n",
        "from tvm.relay.testing import resnet\n",
        "\n",
        "mod, param = resnet.get_workload(num_layers=18)\n",
        "# graphviz attributes\n",
        "graph_attr = {\"color\": \"red\"}\n",
        "node_attr = {\"color\": \"blue\"}\n",
        "edge_attr = {\"color\": \"black\"}\n",
        "\n",
        "# VizNode is passed to the callback.\n",
        "# We want to color NCHW conv2d nodes. Also give Var a different shape.\n",
        "def get_node_attr(node):\n",
        "    if \"nn.conv2d\" in node.type_name and \"NCHW\" in node.detail:\n",
        "        return {\n",
        "            \"fillcolor\": \"green\",\n",
        "            \"style\": \"filled\",\n",
        "            \"shape\": \"box\",\n",
        "        }\n",
        "    if \"Var\" in node.type_name:\n",
        "        return {\"shape\": \"ellipse\"}\n",
        "    return {\"shape\": \"box\"}\n",
        "\n",
        "\n",
        "# Create plotter and pass it to viz. Then render the graph.\n",
        "dot_plotter = DotPlotter(\n",
        "    graph_attr=graph_attr,\n",
        "    node_attr=node_attr,\n",
        "    edge_attr=edge_attr,\n",
        "    get_node_attr=get_node_attr)\n",
        "\n",
        "viz = RelayVisualizer(\n",
        "    mod,\n",
        "    relay_param=param,\n",
        "    plotter=dot_plotter,\n",
        "    parser=DotVizParser())\n",
        "viz.render(\"hello\")  # save to a PDF file"
      ]
    },
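    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Summary\n",
        "\n",
        "This tutorial demonstrated how to use Relay Visualizer and how to customize it.\n",
        "\n",
        "{py:class}`tvm.contrib.relay_viz.RelayVisualizer` is composed of the interfaces defined in ``interface.py``.\n",
        "\n",
        "It is aimed at quick look-then-fix iterations. The constructor arguments are intended to be simple, while customization remains possible through a set of interface classes."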
      ]
    }
  ],
  "metadata": {
    "interpreter": {
      "hash": "aa67ff675248b5ab29dcd2f00c1422844307085c8ca7c8ce7eddecd21b9c2975"
    },
    "kernelspec": {
      "display_name": "Python 3.10.4 ('mxnetx')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
